import torch

from .interface import Loss


class CrossEntropy(Loss):
    """Cross-entropy between predicted probabilities ``pr`` and targets ``gt``.

    Computes ``-mean(gt * log(pr + eps))`` over every element of the
    (broadcast) product. No softmax / log-softmax is applied here, so ``pr``
    must already hold probabilities. NOTE(review): confirm callers pass
    normalized probabilities rather than raw logits.
    """

    def __init__(self):
        super().__init__()
        # Display name for this loss; presumably read by the ``Loss`` base
        # class or by logging/metric code — verify against ``.interface``.
        self.__name__ = 'cross-entropy-loss'

    @staticmethod
    def forward(
        pr: torch.Tensor,
        gt: torch.Tensor,
        *args: object,
        eps: float = 1e-17,
    ) -> torch.Tensor:
        """Return the mean cross-entropy of ``pr`` against ``gt``.

        Args:
            pr: Predicted probabilities; must broadcast against ``gt``.
            gt: Weights applied to ``log(pr)``. Presumably one-hot or soft
                labels; this is not checked here.
            *args: Extra positional arguments; accepted and ignored.
            eps: Offset added to ``pr`` before the logarithm, so that a zero
                probability yields a large finite penalty instead of
                ``log(0) = -inf``. When positive, it is raised to the smallest
                positive normal number of ``pr``'s floating dtype, so it
                cannot underflow to zero (e.g. in float16). Pass ``eps=0`` to
                disable the offset entirely.

        Returns:
            A scalar tensor: the negated mean of ``gt * log(pr + eps)`` over
            *all* elements. If ``pr``/``gt`` are shaped (batch, classes), this
            equals the batch-mean cross-entropy divided by the number of
            classes. NOTE(review): confirm that scaling is intended.
        """
        if eps > 0 and pr.is_floating_point():
            # A fixed 1e-17 rounds to 0 in float16 (smallest subnormal ~6e-8),
            # which would give log(0) = -inf and, where gt == 0,
            # 0 * -inf = NaN. For float32 / float64 / bfloat16, ``tiny`` is far
            # below 1e-17, so the default result there is unchanged.
            eps = max(eps, torch.finfo(pr.dtype).tiny)
        loss = gt * torch.log(pr + eps)
        return -loss.mean()
